{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2d56fa5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 构建一个玩具数据集\n",
    "corpus = [ \"我喜欢吃苹果\",\n",
    "        \"我喜欢吃香蕉\",\n",
    "        \"她喜欢吃葡萄\",\n",
    "        \"他不喜欢吃香蕉\",\n",
    "        \"他喜欢吃苹果\",\n",
    "        \"她喜欢吃草莓\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f31f3204",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义一个分词函数，将文本转换为单个字符的列表\n",
    "def tokenize(text):\n",
    " return [char for char in text] # 将文本拆分为字符列表\n",
    "# 对每个文本进行分词，并打印出对应的单字列表\n",
    "print(\"单字列表:\") \n",
    "for text in corpus:\n",
    "    tokens = tokenize(text)\n",
    "    print(tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "93cb1fa8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "bigram 词频：\n",
      "我: {'喜': 2}\n",
      "喜: {'欢': 6}\n",
      "欢: {'吃': 6}\n",
      "吃: {'苹': 2, '香': 2, '葡': 1, '草': 1}\n",
      "苹: {'果': 2}\n",
      "香: {'蕉': 2}\n",
      "她: {'喜': 2}\n",
      "葡: {'萄': 1}\n",
      "他: {'不': 1, '喜': 1}\n",
      "不: {'喜': 1}\n",
      "草: {'莓': 1}\n"
     ]
    }
   ],
   "source": [
    "# 定义计算 N-Gram 词频的函数\n",
    "from collections import defaultdict, Counter # 导入所需库\n",
    "def count_ngrams(corpus, n):\n",
    "    ngrams_count = defaultdict(Counter)  # 创建一个字典，存储 N-Gram 计数\n",
    "    for text in corpus:  # 遍历语料库中的每个文本\n",
    "        tokens = tokenize(text)  # 对文本进行分词\n",
    "        for i in range(len(tokens) - n + 1):  # 遍历分词结果，生成 N-Gram\n",
    "            ngram = tuple(tokens[i:i+n])  # 创建一个 N-Gram 元组\n",
    "            prefix = ngram[:-1]  # 获取 N-Gram 的前缀\n",
    "            token = ngram[-1]  # 获取 N-Gram 的目标单字\n",
    "            ngrams_count[prefix][token] += 1  # 更新 N-Gram 计数\n",
    "    return ngrams_count\n",
    "bigram_counts = count_ngrams(corpus, 2) # 计算 bigram 词频\n",
    "print(\"bigram 词频：\") # 打印 bigram 词频\n",
    "for prefix, counts in bigram_counts.items():\n",
    "    print(\"{}: {}\".format(\"\".join(prefix), dict(counts))) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "06678ebf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "bigram 出现的概率 :\n",
      "我: {'喜': 1.0}\n",
      "喜: {'欢': 1.0}\n",
      "欢: {'吃': 1.0}\n",
      "吃: {'苹': 0.3333333333333333, '香': 0.3333333333333333, '葡': 0.16666666666666666, '草': 0.16666666666666666}\n",
      "苹: {'果': 1.0}\n",
      "香: {'蕉': 1.0}\n",
      "她: {'喜': 1.0}\n",
      "葡: {'萄': 1.0}\n",
      "他: {'不': 0.5, '喜': 0.5}\n",
      "不: {'喜': 1.0}\n",
      "草: {'莓': 1.0}\n"
     ]
    }
   ],
   "source": [
    "# 定义计算 N-Gram 出现概率的函数\n",
    "def ngram_probabilities(ngram_counts):\n",
    " ngram_probs = defaultdict(Counter) # 创建一个字典，存储 N-Gram 出现的概率\n",
    " for prefix, tokens_count in ngram_counts.items(): # 遍历 N-Gram 前缀\n",
    "     total_count = sum(tokens_count.values()) # 计算当前前缀的 N-Gram 计数\n",
    "     for token, count in tokens_count.items(): # 遍历每个前缀的 N-Gram\n",
    "         ngram_probs[prefix][token] = count / total_count # 计算每个 N-Gram 出现的概率\n",
    " return ngram_probs\n",
    "bigram_probs = ngram_probabilities(bigram_counts) # 计算 bigram 出现的概率\n",
    "print(\"\\nbigram 出现的概率 :\") # 打印 bigram 概率\n",
    "for prefix, probs in bigram_probs.items():\n",
    " print(\"{}: {}\".format(\"\".join(prefix), dict(probs)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fe633e97",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义生成下一个词的函数\n",
    "def generate_next_token(prefix, ngram_probs):\n",
    " if not prefix in ngram_probs: # 如果前缀不在 N-Gram 中，返回 None\n",
    "    return None\n",
    " next_token_probs = ngram_probs[prefix] # 获取当前前缀的下一个词的概率\n",
    " next_token = max(next_token_probs, \n",
    "                    key=next_token_probs.get) # 选择概率最大的词作为下一个词\n",
    " return next_token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "12da246f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义生成连续文本的函数\n",
    "def generate_text(prefix, ngram_probs, n, length=6):\n",
    " tokens = list(prefix) # 将前缀转换为字符列表\n",
    " for _ in range(length - len(prefix)): # 根据指定长度生成文本 \n",
    "     # 获取当前前缀的下一个词\n",
    "     next_token = generate_next_token(tuple(tokens[-(n-1):]), ngram_probs) \n",
    "     if not next_token: # 如果下一个词为 None，跳出循环\n",
    "         break\n",
    "     tokens.append(next_token) # 将下一个词添加到生成的文本中\n",
    " return \"\".join(tokens) # 将字符列表连接成字符串"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7af72d40",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " 生成的文本： 我喜欢吃苹果\n"
     ]
    }
   ],
   "source": [
    "# 输入一个前缀，生成文本\n",
    "generated_text = generate_text(\"我\", bigram_probs, 2)\n",
    "print(\"\\n 生成的文本：\", generated_text) # 打印生成的文本"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d3ed012",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
